import graphlearning as gl
import time
import numpy as np
import method
import time


# D = knn_dist*knn_dist
# eps = D[:,k-1]
# weights = eta(D/eps[:,None])


def gaussian_eta(a=0.25, b=0.0):
    """Build a shifted Gaussian weight profile for ``gl.weightmatrix.knn``.

    The returned callable maps an array ``x`` to ``exp(-x / a)`` minus its
    row-wise minimum, plus the constant offset ``b``. After the shift, the
    smallest weight in every row is exactly ``b``. The offset therefore sets
    the floor of the kernel directly.

    Parameters
    ----------
    a : float, default 0.25
        Bandwidth of the exponential; larger values make the weights decay
        more slowly with ``x``.
    b : float, default 0.0
        Constant added after the row-wise shift.

    Returns
    -------
    callable
        ``eta(x)`` returning an array with the same shape as ``x``. The
        minimum is taken over axis 1, so ``x`` is assumed to be 2-D with one
        row per point (e.g. scaled k-NN distances) -- TODO confirm against
        how ``gl.weightmatrix.knn`` invokes ``eta``.
    """
    def eta(x):
        # Evaluate the exponential once and reuse it for the row minimum.
        kernel = np.exp(-x / a)
        return kernel - kernel.min(axis=1, keepdims=True) + b

    return eta


# eta = lambda x: np.exp(-4 * x) - np.exp(-4.5)
# eta = lambda x: np.max(x) - x + 2
# eta = lambda x: np.max(x, axis=-1, keepdims=True) - x + 1 # in the paper
# eta = lambda x: 4 - x
# eta = None
# eta = lambda x: np.exp(-x / 4) + 2

# eta = None

# Dataset and the `metric` argument passed to gl.weightmatrix.knn.
# Alternative configuration: dataset = 'mnist' with metric = 'vae'
# (metric = 'raw' was the choice for 'cifar10', 'cifar100' and 'svhn').
dataset = 'cifar'
metric = 'aet'

# np.random.seed(2)

labels = gl.datasets.load(dataset, labels_only=True)
num_train_per_class = 1
train_ind = gl.trainsets.generate(labels, rate=num_train_per_class)
train_labels = labels[train_ind]

# Uniform class priors, derived from the labels rather than hard-coded.
# Only the commented-out MBO models below consume them.
num_classes = len(np.unique(labels))
class_priors = np.ones(num_classes) / num_classes

# Probe for a GPU once, before the sweep. If torch is not installed or no
# CUDA device is available, use_cuda is False. Only the ImportError is
# caught, so Ctrl-C and genuine errors still propagate.
try:
    import torch
except ImportError:
    use_cuda = False
else:
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(0)

# Sweep the offset b of the Gaussian weight profile. The graph is rebuilt for
# every b because the edge weights passed to knn depend on it.
# for b in np.linspace(0, np.exp(-4), 10):
for b in [0.0, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.5]:
    eta = gaussian_eta(b=b)
    W = gl.weightmatrix.knn(dataset, 10, kernel='gaussian', eta=eta, metric=metric)

    # lambs = [0.0, 0.01, 0.05, 0.1, 0.2, 0.5, 1.0]
    # models = [method.v_lapalce(W, lamb=lamb, use_cuda=use_cuda) for lamb in lambs]
    # models += [method.v_poisson(W, lamb=lamb, use_cuda=use_cuda) for lamb in lambs]
    # models = [gl.ssl.poisson(W), gl.ssl.poisson_mbo(W, class_priors=class_priors)]
    # models = [gl.ssl.poisson(W, solver='gradient_descent')]
    # models = [method.v_poisson(W, var='weighted', mode='accum', lamb=b)]
    # models = [gl.ssl.laplace(W)]
    models = [gl.ssl.plaplace(W)]
    # models = [gl.ssl.poisson(W)]
    # models = [gl.ssl.poisson_mbo(W, solver='gradient_descent', use_cuda=use_cuda, class_priors=class_priors)] + models + [method.v_poisson_mbo(W, lamb=lamb, use_cuda=use_cuda, class_priors=class_priors) for lamb in lambs]

    for model in models:
        # Time only fitting and prediction, with a monotonic clock.
        start = time.perf_counter()
        pred_labels = model.fit_predict(train_ind, train_labels, all_labels=None)
        accuracy = gl.ssl.ssl_accuracy(labels, pred_labels, len(train_ind))
        elapsed = time.perf_counter() - start
        # Include b so each line of the sweep output is attributable.
        print(f'b={b:g} {model.name}: {accuracy:.2f}% in {elapsed:.2f}s')